/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.intel.analytics.bigdl.dllib.nn

import com.intel.analytics.bigdl.dllib.tensor.Tensor
import com.intel.analytics.bigdl.dllib.utils.serializer.ModuleSerializationTest
import com.intel.analytics.bigdl.dllib.utils.{RandomGenerator, T, Table, TestUtils}
import org.scalatest.{FlatSpec, Matchers}

class MaskHeadSpec extends FlatSpec with Matchers {
  // Runs MaskHead forward on a single image (batch size 1) with two RoIs and two feature
  // maps, after filling every parameter with the same constant so that the outputs can be
  // compared against precomputed reference values.
  "MaskHead" should "be ok" in {
    val inChannels: Int = 6
    val resolution: Int = 14
    val scales: Array[Float] = Array[Float](0.25f, 0.125f)
    val samplingRatio: Int = 2
    val layers: Array[Int] = Array[Int](4, 4)
    val dilation: Int = 1
    val numClasses: Int = 81
    val useGn: Boolean = false

    val layer = new MaskHead(inChannels, resolution, scales,
      samplingRatio, layers, dilation, numClasses, useGn)

    // Every weight and bias gets the same value, which makes the forward pass deterministic.
    val params = layer.getParameters()
    params._1.fill(0.001f)

    val features1 = Tensor[Float](T(T(T(T(0.5381, 0.0856, 0.1124, 0.7493),
      T(0.4624, 0.2182, 0.7364, 0.3522),
      T(0.7552, 0.7117, 0.2715, 0.9082)),
      T(T(0.0928, 0.2735, 0.7539, 0.7539),
        T(0.4777, 0.1525, 0.8279, 0.6481),
        T(0.6019, 0.4803, 0.5869, 0.7459)),
      T(T(0.1924, 0.2795, 0.4463, 0.3887),
        T(0.5791, 0.9832, 0.8752, 0.4598),
        T(0.2278, 0.0758, 0.4988, 0.3742)),
      T(T(0.1762, 0.6499, 0.2534, 0.9842),
        T(0.0908, 0.8676, 0.1700, 0.1887),
        T(0.7138, 0.9559, 0.0119, 0.7799)),
      T(T(0.8200, 0.6767, 0.3637, 0.9771),
        T(0.1217, 0.5645, 0.2574, 0.6729),
        T(0.6140, 0.5333, 0.4425, 0.1740)),
      T(T(0.3994, 0.9148, 0.0123, 0.0125),
        T(0.5663, 0.9951, 0.8143, 0.9906),
        T(0.0923, 0.8285, 0.2992, 0.2221)))))

    val features2 = Tensor[Float](T(T(T(T(0.0492, 0.1234),
      T(0.3291, 0.0613),
      T(0.4260, 0.1422),
      T(0.2282, 0.4258),
      T(0.7426, 0.9476)),
      T(T(0.6662, 0.7015),
        T(0.4598, 0.6378),
        T(0.9571, 0.4947),
        T(0.1659, 0.3034),
        T(0.8583, 0.1369)),
      T(T(0.1711, 0.6440),
        T(0.2099, 0.4468),
        T(0.9518, 0.3877),
        T(0.4058, 0.6630),
        T(0.9056, 0.4054)),
      T(T(0.4562, 0.0277),
        T(0.2358, 0.3938),
        T(0.9187, 0.4067),
        T(0.0445, 0.4171),
        T(0.3434, 0.1964)),
      T(T(0.9473, 0.7239),
        T(0.1732, 0.5352),
        T(0.8276, 0.6435),
        T(0.3516, 0.3760),
        T(0.3437, 0.0198)),
      T(T(0.7811, 0.5682),
        T(0.5121, 0.9655),
        T(0.3496, 0.7632),
        T(0.4267, 0.4533),
        T(0.8624, 0.3172)))))

    val bbox = Tensor[Float](T(T(1.0f, 3.0f, 2.0f, 6.0f),
      T(3.0f, 5.0f, 6.0f, 10.0f)))
    val labels = Tensor[Float](T(1, 3))

    val output = layer.forward(T(T(features1, features2), T(bbox), labels)).toTable

    // Reference map for the first RoI (resolution x resolution).
    val roi1Mask = Tensor[Float](T(
      T(0.0013, 0.0015, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,
        0.0016, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013),
      T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0018, 0.0015),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0021, 0.0021, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0017, 0.0018, 0.0018, 0.0018, 0.0018, 0.0019, 0.0019,
        0.0019, 0.0019, 0.0019, 0.0019, 0.0018, 0.0015),
      T(0.0013, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013)))

    // Reference map for the second RoI (resolution x resolution).
    val roi2Mask = Tensor[Float](T(
      T(0.0013, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0013),
      T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019,
        0.0019, 0.0019, 0.0018, 0.0018, 0.0017, 0.0015),
      T(0.0016, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0019, 0.0018, 0.0015),
      T(0.0016, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0015),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0015),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0017, 0.0018, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019,
        0.0019, 0.0019, 0.0019, 0.0019, 0.0018, 0.0015),
      T(0.0013, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0016,
        0.0016, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013)))

    // The expected first output is (numRois = 2) x numChannels x resolution x resolution.
    // Because every parameter holds the same constant, each conv output channel receives
    // identical weights and bias, so all channels of a RoI carry the same map; build the
    // expected tensor by replicating each RoI's reference map across its channels.
    val numChannels = 4
    val expectedOutput = Tensor[Float](2, numChannels, resolution, resolution)
    Seq(roi1Mask, roi2Mask).zipWithIndex.foreach { case (roiMask, roiIdx) =>
      (1 to numChannels).foreach { channel =>
        expectedOutput.select(1, roiIdx + 1).select(1, channel).copy(roiMask)
      }
    }

    // NOTE(review): a 1e-3 tolerance is loose relative to the ~0.0013-0.0021 reference
    // magnitudes; consider tightening it once the reference values are confirmed.
    output[Tensor[Float]](1).almostEqual(expectedOutput, 1e-3) should be(true)

    // Every element of the second output must be ~0.5003. Bounding min and max checks all
    // elements without an in-place map, and the size check prevents a vacuous pass.
    val output2 = output[Tensor[Float]](2)
    output2.nElement() should be > 0
    output2.min() should be(0.5003f +- 1e-3f)
    output2.max() should be(0.5003f +- 1e-3f)
  }

  "MaskHead with batch size > 1" should "be ok" in {
    val inChannels: Int = 6
    val resolution: Int = 14
    val scales: Array[Float] = Array[Float](0.25f, 0.125f)
    val samplingRatio: Int = 2
    val layers: Array[Int] = Array[Int](4, 4)
    val dilation: Int = 1
    val numClasses: Int = 81
    val useGn: Boolean = false

    val layer = new MaskHead(inChannels, resolution, scales,
      samplingRatio, layers, dilation, numClasses, useGn)

    val params = layer.getParameters()
    params._1.fill(0.001f)

    val features1 = Tensor[Float](T(T(T(T(0.5381, 0.0856, 0.1124, 0.7493),
      T(0.4624, 0.2182, 0.7364, 0.3522),
      T(0.7552, 0.7117, 0.2715, 0.9082)),
      T(T(0.0928, 0.2735, 0.7539, 0.7539),
        T(0.4777, 0.1525, 0.8279, 0.6481),
        T(0.6019, 0.4803, 0.5869, 0.7459)),
      T(T(0.1924, 0.2795, 0.4463, 0.3887),
        T(0.5791, 0.9832, 0.8752, 0.4598),
        T(0.2278, 0.0758, 0.4988, 0.3742)),
      T(T(0.1762, 0.6499, 0.2534, 0.9842),
        T(0.0908, 0.8676, 0.1700, 0.1887),
        T(0.7138, 0.9559, 0.0119, 0.7799)),
      T(T(0.8200, 0.6767, 0.3637, 0.9771),
        T(0.1217, 0.5645, 0.2574, 0.6729),
        T(0.6140, 0.5333, 0.4425, 0.1740)),
      T(T(0.3994, 0.9148, 0.0123, 0.0125),
        T(0.5663, 0.9951, 0.8143, 0.9906),
        T(0.0923, 0.8285, 0.2992, 0.2221)))))

    val features2 = Tensor[Float](T(T(T(T(0.0492, 0.1234),
      T(0.3291, 0.0613),
      T(0.4260, 0.1422),
      T(0.2282, 0.4258),
      T(0.7426, 0.9476)),
      T(T(0.6662, 0.7015),
        T(0.4598, 0.6378),
        T(0.9571, 0.4947),
        T(0.1659, 0.3034),
        T(0.8583, 0.1369)),
      T(T(0.1711, 0.6440),
        T(0.2099, 0.4468),
        T(0.9518, 0.3877),
        T(0.4058, 0.6630),
        T(0.9056, 0.4054)),
      T(T(0.4562, 0.0277),
        T(0.2358, 0.3938),
        T(0.9187, 0.4067),
        T(0.0445, 0.4171),
        T(0.3434, 0.1964)),
      T(T(0.9473, 0.7239),
        T(0.1732, 0.5352),
        T(0.8276, 0.6435),
        T(0.3516, 0.3760),
        T(0.3437, 0.0198)),
      T(T(0.7811, 0.5682),
        T(0.5121, 0.9655),
        T(0.3496, 0.7632),
        T(0.4267, 0.4533),
        T(0.8624, 0.3172)))))

    val bbox = Tensor[Float](T(T(1.0f, 3.0f, 2.0f, 6.0f),
      T(3.0f, 5.0f, 6.0f, 10.0f)))
    val labels = Tensor[Float](T(1, 3))

    val features1Batch = Tensor[Float](2, features1.size(2), features1.size(3), features1.size(4))
    features1Batch.select(1, 1).copy(features1)
    features1Batch.select(1, 2).copy(features1)

    val features2Batch = Tensor[Float](2, features2.size(2), features2.size(3), features2.size(4))
    features2Batch.select(1, 1).copy(features2)
    features2Batch.select(1, 2).copy(features2)

    val bboxBatch = T(Tensor[Float](T(T(1.0f, 3.0f, 2.0f, 6.0f), T(3.0f, 5.0f, 6.0f, 7.0f))),
      Tensor[Float](T(T(1.0f, 3.0f, 2.0f, 6.0f), T(3.0f, 5.0f, 6.0f, 7.0f))))

    val labelsBatch = Tensor[Float](T(1, 3, 1, 3))

    val output = layer.forward(T(T(features1Batch, features2Batch), bboxBatch, labelsBatch)).toTable

    val expectedOutput = Tensor[Float](T(T(T(
      T(0.0013, 0.0015, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,
        0.0016, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013),
      T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0018, 0.0015),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0021, 0.0021, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
      T(0.0015, 0.0017, 0.0018, 0.0018, 0.0018, 0.0018, 0.0019, 0.0019,
        0.0019, 0.0019, 0.0019, 0.0019, 0.0018, 0.0015),
      T(0.0013, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
        0.0015, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013)),

      T(T(0.0013, 0.0015, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,
        0.0016, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0018, 0.0015),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021,
          0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021,
          0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0021, 0.0021, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0017, 0.0018, 0.0018, 0.0018, 0.0018, 0.0019, 0.0019,
          0.0019, 0.0019, 0.0019, 0.0019, 0.0018, 0.0015),
        T(0.0013, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
          0.0015, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013)),

      T(T(0.0013, 0.0015, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,
        0.0016, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0018, 0.0015),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021,
          0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021,
          0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0021, 0.0021, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0017, 0.0018, 0.0018, 0.0018, 0.0018, 0.0019, 0.0019,
          0.0019, 0.0019, 0.0019, 0.0019, 0.0018, 0.0015),
        T(0.0013, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
          0.0015, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013)),

      T(T(0.0013, 0.0015, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,
        0.0016, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0018, 0.0015),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021,
          0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021,
          0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0021, 0.0021, 0.0021, 0.0021, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0021, 0.0021, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0017, 0.0018, 0.0018, 0.0018, 0.0018, 0.0019, 0.0019,
          0.0019, 0.0019, 0.0019, 0.0019, 0.0018, 0.0015),
        T(0.0013, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,
          0.0015, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013))),


      T(T(T(0.0013, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0013),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019,
          0.0019, 0.0019, 0.0018, 0.0018, 0.0017, 0.0015),
        T(0.0016, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0019, 0.0018, 0.0015),
        T(0.0016, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0015),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0015),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
          0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
        T(0.0015, 0.0017, 0.0018, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019,
          0.0019, 0.0019, 0.0019, 0.0019, 0.0018, 0.0015),
        T(0.0013, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0016,
          0.0016, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013)),

        T(T(0.0013, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,
          0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0013),
          T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019,
            0.0019, 0.0019, 0.0018, 0.0018, 0.0017, 0.0015),
          T(0.0016, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0019, 0.0018, 0.0015),
          T(0.0016, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0015),
          T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0015),
          T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0017, 0.0018, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019,
            0.0019, 0.0019, 0.0019, 0.0019, 0.0018, 0.0015),
          T(0.0013, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0016,
            0.0016, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013)),

        T(T(0.0013, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,
          0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0013),
          T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019,
            0.0019, 0.0019, 0.0018, 0.0018, 0.0017, 0.0015),
          T(0.0016, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0019, 0.0018, 0.0015),
          T(0.0016, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0015),
          T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0015),
          T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0017, 0.0018, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019,
            0.0019, 0.0019, 0.0019, 0.0019, 0.0018, 0.0015),
          T(0.0013, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0016,
            0.0016, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013)),

        T(T(0.0013, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,
          0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0013),
          T(0.0015, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019,
            0.0019, 0.0019, 0.0018, 0.0018, 0.0017, 0.0015),
          T(0.0016, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0019, 0.0018, 0.0015),
          T(0.0016, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0015),
          T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0015),
          T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0018, 0.0019, 0.0019, 0.0020, 0.0020, 0.0020, 0.0020,
            0.0020, 0.0020, 0.0020, 0.0020, 0.0019, 0.0016),
          T(0.0015, 0.0017, 0.0018, 0.0018, 0.0019, 0.0019, 0.0019, 0.0019,
            0.0019, 0.0019, 0.0019, 0.0019, 0.0018, 0.0015),
          T(0.0013, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0016,
            0.0016, 0.0016, 0.0016, 0.0016, 0.0015, 0.0013)))))

    output[Tensor[Float]](1).narrow(1, 1, 2).almostEqual(expectedOutput, 1e-3) should be(true)
    output[Tensor[Float]](2).narrow(1, 1, 2).apply1(a => {
      a should be(0.5003f +- 1e-3f)
      a
    })
  }

  // Runs MaskHead's FPN mask feature extractor on two single-image feature
  // maps and two boxes, with every "weight" parameter forced to 0.01, and
  // compares the result against a reference tensor within 1e-3.
  "MaskRCNNFPNFeatureExtractor" should "be ok" in {
    val inChannels = 6
    val resolution = 14
    val scales = Array[Float](0.25f, 0.125f)
    val samplingRatio = 2
    val layers = Array[Int](4, 4)
    val dilation = 1
    val useGn = false

    // 81 fills the numClasses slot of MaskHead (as in the other tests);
    // maskFeatureExtractor itself does not take a class count.
    val maskHead = new MaskHead(inChannels, resolution, scales,
      samplingRatio, layers, dilation, 81, useGn)
    val extractor = maskHead.maskFeatureExtractor(inChannels, resolution, scales,
      samplingRatio, layers, dilation, useGn)

    // Make the extractor deterministic: overwrite every "weight" entry with 0.01.
    // Other parameter entries are left untouched.
    val paramsTable = extractor.getParametersTable()
    paramsTable.keySet.foreach { key =>
      paramsTable.get[Table](key).get.get[Tensor[Float]]("weight").get.fill(0.01f)
    }

    // Two feature maps, each a batch of one image with 6 channels:
    // features1 is 1x6x3x4, features2 is 1x6x5x2.
    val features1 = Tensor[Float](T(T(
      T(T(0.5381, 0.0856, 0.1124, 0.7493),
        T(0.4624, 0.2182, 0.7364, 0.3522),
        T(0.7552, 0.7117, 0.2715, 0.9082)),
      T(T(0.0928, 0.2735, 0.7539, 0.7539),
        T(0.4777, 0.1525, 0.8279, 0.6481),
        T(0.6019, 0.4803, 0.5869, 0.7459)),
      T(T(0.1924, 0.2795, 0.4463, 0.3887),
        T(0.5791, 0.9832, 0.8752, 0.4598),
        T(0.2278, 0.0758, 0.4988, 0.3742)),
      T(T(0.1762, 0.6499, 0.2534, 0.9842),
        T(0.0908, 0.8676, 0.1700, 0.1887),
        T(0.7138, 0.9559, 0.0119, 0.7799)),
      T(T(0.8200, 0.6767, 0.3637, 0.9771),
        T(0.1217, 0.5645, 0.2574, 0.6729),
        T(0.6140, 0.5333, 0.4425, 0.1740)),
      T(T(0.3994, 0.9148, 0.0123, 0.0125),
        T(0.5663, 0.9951, 0.8143, 0.9906),
        T(0.0923, 0.8285, 0.2992, 0.2221)))))

    val features2 = Tensor[Float](T(T(
      T(T(0.0492, 0.1234),
        T(0.3291, 0.0613),
        T(0.4260, 0.1422),
        T(0.2282, 0.4258),
        T(0.7426, 0.9476)),
      T(T(0.6662, 0.7015),
        T(0.4598, 0.6378),
        T(0.9571, 0.4947),
        T(0.1659, 0.3034),
        T(0.8583, 0.1369)),
      T(T(0.1711, 0.6440),
        T(0.2099, 0.4468),
        T(0.9518, 0.3877),
        T(0.4058, 0.6630),
        T(0.9056, 0.4054)),
      T(T(0.4562, 0.0277),
        T(0.2358, 0.3938),
        T(0.9187, 0.4067),
        T(0.0445, 0.4171),
        T(0.3434, 0.1964)),
      T(T(0.9473, 0.7239),
        T(0.1732, 0.5352),
        T(0.8276, 0.6435),
        T(0.3516, 0.3760),
        T(0.3437, 0.0198)),
      T(T(0.7811, 0.5682),
        T(0.5121, 0.9655),
        T(0.3496, 0.7632),
        T(0.4267, 0.4533),
        T(0.8624, 0.3172)))))

    // Two boxes, one per row, four coordinates each
    // (presumably x1, y1, x2, y2 — TODO confirm against the pooler).
    val bbox = Tensor[Float](T(
      T(1.0f, 3.0f, 2.0f, 6.0f),
      T(3.0f, 5.0f, 6.0f, 10.0f)))

    // The reference output is 2 x 4 x 14 x 14 with one slice per box. Inside a
    // box all 4 channels hold the same 14x14 map, so each map is written once
    // below and then copied into every channel of its box.
    val box1Map = Tensor[Float](T(
      T(0.0224, 0.0359, 0.0404, 0.0404, 0.0404, 0.0404, 0.0404, 0.0404,
        0.0404, 0.0404, 0.0404, 0.0404, 0.0359, 0.0224),
      T(0.0359, 0.0575, 0.0647, 0.0647, 0.0647, 0.0647, 0.0647, 0.0647,
        0.0647, 0.0647, 0.0647, 0.0647, 0.0575, 0.0359),
      T(0.0405, 0.0648, 0.0729, 0.0729, 0.0729, 0.0729, 0.0729, 0.0729,
        0.0729, 0.0729, 0.0729, 0.0729, 0.0648, 0.0405),
      T(0.0406, 0.0649, 0.0730, 0.0730, 0.0730, 0.0730, 0.0730, 0.0730,
        0.0730, 0.0730, 0.0730, 0.0730, 0.0649, 0.0406),
      T(0.0406, 0.0650, 0.0732, 0.0732, 0.0732, 0.0732, 0.0732, 0.0732,
        0.0732, 0.0732, 0.0732, 0.0732, 0.0650, 0.0406),
      T(0.0407, 0.0651, 0.0733, 0.0733, 0.0733, 0.0733, 0.0733, 0.0733,
        0.0733, 0.0733, 0.0733, 0.0733, 0.0651, 0.0407),
      T(0.0408, 0.0653, 0.0734, 0.0734, 0.0734, 0.0734, 0.0734, 0.0734,
        0.0734, 0.0734, 0.0734, 0.0734, 0.0653, 0.0408),
      T(0.0409, 0.0654, 0.0736, 0.0736, 0.0736, 0.0736, 0.0736, 0.0736,
        0.0736, 0.0736, 0.0736, 0.0736, 0.0654, 0.0409),
      T(0.0409, 0.0655, 0.0737, 0.0737, 0.0737, 0.0737, 0.0737, 0.0737,
        0.0737, 0.0737, 0.0737, 0.0737, 0.0655, 0.0409),
      T(0.0410, 0.0656, 0.0738, 0.0738, 0.0738, 0.0738, 0.0738, 0.0738,
        0.0738, 0.0738, 0.0738, 0.0738, 0.0656, 0.0410),
      T(0.0411, 0.0658, 0.0740, 0.0740, 0.0740, 0.0740, 0.0740, 0.0740,
        0.0740, 0.0740, 0.0740, 0.0740, 0.0658, 0.0411),
      T(0.0412, 0.0659, 0.0741, 0.0741, 0.0741, 0.0741, 0.0741, 0.0741,
        0.0741, 0.0741, 0.0741, 0.0741, 0.0659, 0.0412),
      T(0.0366, 0.0586, 0.0660, 0.0660, 0.0660, 0.0660, 0.0660, 0.0660,
        0.0660, 0.0660, 0.0660, 0.0660, 0.0586, 0.0366),
      T(0.0229, 0.0367, 0.0413, 0.0413, 0.0413, 0.0413, 0.0413, 0.0413,
        0.0413, 0.0413, 0.0413, 0.0413, 0.0367, 0.0229)))

    val box2Map = Tensor[Float](T(
      T(0.0273, 0.0442, 0.0507, 0.0520, 0.0533, 0.0546, 0.0560, 0.0573,
        0.0586, 0.0599, 0.0613, 0.0626, 0.0565, 0.0356),
      T(0.0438, 0.0710, 0.0815, 0.0837, 0.0858, 0.0880, 0.0901, 0.0923,
        0.0944, 0.0965, 0.0987, 0.1008, 0.0911, 0.0575),
      T(0.0498, 0.0806, 0.0925, 0.0950, 0.0974, 0.0999, 0.1023, 0.1047,
        0.1072, 0.1096, 0.1121, 0.1145, 0.1034, 0.0653),
      T(0.0504, 0.0816, 0.0936, 0.0960, 0.0985, 0.1009, 0.1033, 0.1057,
        0.1081, 0.1106, 0.1130, 0.1154, 0.1042, 0.0657),
      T(0.0510, 0.0826, 0.0947, 0.0970, 0.0993, 0.1017, 0.1040, 0.1063,
        0.1086, 0.1110, 0.1133, 0.1156, 0.1043, 0.0658),
      T(0.0517, 0.0836, 0.0957, 0.0979, 0.1001, 0.1023, 0.1045, 0.1067,
        0.1089, 0.1111, 0.1133, 0.1155, 0.1041, 0.0656),
      T(0.0524, 0.0846, 0.0968, 0.0988, 0.1009, 0.1029, 0.1050, 0.1070,
        0.1091, 0.1112, 0.1132, 0.1153, 0.1038, 0.0654),
      T(0.0531, 0.0857, 0.0978, 0.0997, 0.1016, 0.1036, 0.1055, 0.1074,
        0.1093, 0.1112, 0.1131, 0.1151, 0.1036, 0.0652),
      T(0.0537, 0.0867, 0.0988, 0.1006, 0.1024, 0.1042, 0.1060, 0.1077,
        0.1095, 0.1113, 0.1131, 0.1148, 0.1033, 0.0650),
      T(0.0544, 0.0877, 0.0999, 0.1015, 0.1032, 0.1048, 0.1064, 0.1081,
        0.1097, 0.1113, 0.1130, 0.1146, 0.1030, 0.0648),
      T(0.0551, 0.0887, 0.1009, 0.1024, 0.1039, 0.1054, 0.1069, 0.1084,
        0.1099, 0.1114, 0.1129, 0.1144, 0.1027, 0.0646),
      T(0.0557, 0.0897, 0.1020, 0.1033, 0.1047, 0.1060, 0.1074, 0.1088,
        0.1101, 0.1115, 0.1128, 0.1142, 0.1024, 0.0643),
      T(0.0500, 0.0804, 0.0913, 0.0925, 0.0936, 0.0947, 0.0958, 0.0969,
        0.0980, 0.0991, 0.1002, 0.1014, 0.0908, 0.0571),
      T(0.0314, 0.0505, 0.0573, 0.0580, 0.0587, 0.0593, 0.0600, 0.0607,
        0.0613, 0.0620, 0.0626, 0.0633, 0.0567, 0.0356)))

    val numChannels = 4
    val expectedOutput = Tensor[Float](2, numChannels, resolution, resolution)
    for {
      (boxMap, boxIdx) <- Seq(box1Map, box2Map).zipWithIndex
      channel <- 1 to numChannels
    } expectedOutput.select(1, boxIdx + 1).select(1, channel).copy(boxMap)

    val output = extractor.forward(T(T(features1, features2), T(bbox))).toTensor[Float]

    output.almostEqual(expectedOutput, 1e-3) should be(true)
  }

  "MaskRCNNC4Predictor" should "be ok" in {
    RandomGenerator.RNG.setSeed(100)

    val input = Tensor[Float](T(T(T(
      T(0.1117, 0.8158, 0.2626, 0.4839, 0.6765, 0.7539, 0.2627, 0.0428),
      T(0.2080, 0.1180, 0.1217, 0.7356, 0.7118, 0.7876, 0.4183, 0.9014),
      T(0.9969, 0.7565, 0.2239, 0.3023, 0.1784, 0.8238, 0.5557, 0.9770),
      T(0.4440, 0.9478, 0.7445, 0.4892, 0.2426, 0.7003, 0.5277, 0.2472),
      T(0.7909, 0.4235, 0.0169, 0.2209, 0.9535, 0.7064, 0.1629, 0.8902),
      T(0.5163, 0.0359, 0.6476, 0.3430, 0.3182, 0.5261, 0.0447, 0.5123),
      T(0.9051, 0.5989, 0.4450, 0.7278, 0.4563, 0.3389, 0.6211, 0.5530),
      T(0.6896, 0.3687, 0.9053, 0.8356, 0.3039, 0.6726, 0.5740, 0.9233)),
      T(T(0.9178, 0.7590, 0.7775, 0.6179, 0.3379, 0.2170, 0.9454, 0.7116),
        T(0.1157, 0.6574, 0.3451, 0.0453, 0.9798, 0.5548, 0.6868, 0.4920),
        T(0.0748, 0.9605, 0.3271, 0.0103, 0.9516, 0.2855, 0.2324, 0.9141),
        T(0.7668, 0.1659, 0.4393, 0.2243, 0.8935, 0.0497, 0.1780, 0.3011),
        T(0.1893, 0.9186, 0.2131, 0.3957, 0.6017, 0.4234, 0.5224, 0.4175),
        T(0.0340, 0.9157, 0.3079, 0.6269, 0.8277, 0.6594, 0.0887, 0.4890),
        T(0.5887, 0.7340, 0.8497, 0.9112, 0.4847, 0.9436, 0.3904, 0.2499),
        T(0.3206, 0.9753, 0.7582, 0.6688, 0.2651, 0.2336, 0.5057, 0.5688)),
      T(T(0.0634, 0.8993, 0.2732, 0.3397, 0.1879, 0.5534, 0.2682, 0.9556),
        T(0.9761, 0.5934, 0.3124, 0.9431, 0.8519, 0.9815, 0.1132, 0.4783),
        T(0.4436, 0.3847, 0.4521, 0.5569, 0.9952, 0.0015, 0.0813, 0.4907),
        T(0.2130, 0.4603, 0.1386, 0.0277, 0.5662, 0.3503, 0.6555, 0.7667),
        T(0.2269, 0.7555, 0.6458, 0.3673, 0.1770, 0.2966, 0.9925, 0.2103),
        T(0.1292, 0.1719, 0.9127, 0.6818, 0.1953, 0.9991, 0.1133, 0.0135),
        T(0.1450, 0.7819, 0.3134, 0.2983, 0.3436, 0.2028, 0.9792, 0.4947),
        T(0.3617, 0.9687, 0.0359, 0.3041, 0.9867, 0.1290, 0.6887, 0.1637))),
      T(T(T(0.0899, 0.3139, 0.1219, 0.3516, 0.2316, 0.2847, 0.3520, 0.2828),
        T(0.2420, 0.4928, 0.5772, 0.3771, 0.2440, 0.8994, 0.1041, 0.9193),
        T(0.6201, 0.3658, 0.0623, 0.5967, 0.0829, 0.8185, 0.4964, 0.0589),
        T(0.9840, 0.5836, 0.6737, 0.4738, 0.9336, 0.2557, 0.1506, 0.7856),
        T(0.4152, 0.5809, 0.1088, 0.7065, 0.0105, 0.4602, 0.2945, 0.0475),
        T(0.6401, 0.3784, 0.5887, 0.0720, 0.9140, 0.0085, 0.2174, 0.1890),
        T(0.0911, 0.6344, 0.3142, 0.7052, 0.6447, 0.9517, 0.3581, 0.3411),
        T(0.0433, 0.4373, 0.9947, 0.1748, 0.1374, 0.8005, 0.7004, 0.8803)),
        T(T(0.1573, 0.3343, 0.9652, 0.1862, 0.1508, 0.3183, 0.0321, 0.3290),
          T(0.5301, 0.6401, 0.7954, 0.3066, 0.2397, 0.1156, 0.4839, 0.3944),
          T(0.0801, 0.7782, 0.6686, 0.2312, 0.1164, 0.1921, 0.2380, 0.1643),
          T(0.1724, 0.8462, 0.1072, 0.7113, 0.1406, 0.2950, 0.3264, 0.4708),
          T(0.3978, 0.7055, 0.9162, 0.8060, 0.7267, 0.8054, 0.1696, 0.2023),
          T(0.9194, 0.0151, 0.0324, 0.9538, 0.5564, 0.7567, 0.1573, 0.3969),
          T(0.2381, 0.1268, 0.4460, 0.0370, 0.6442, 0.8108, 0.2550, 0.8608),
          T(0.8250, 0.2236, 0.0772, 0.4818, 0.0776, 0.0531, 0.2610, 0.1068)),
        T(T(0.3011, 0.4587, 0.5222, 0.0683, 0.9118, 0.8286, 0.1635, 0.1775),
          T(0.7163, 0.9355, 0.1430, 0.3933, 0.1124, 0.3087, 0.9973, 0.4257),
          T(0.6890, 0.9657, 0.0257, 0.4205, 0.0656, 0.4508, 0.0553, 0.3140),
          T(0.7460, 0.9357, 0.8925, 0.1370, 0.1803, 0.4023, 0.4296, 0.3692),
          T(0.1611, 0.9422, 0.8777, 0.5321, 0.5392, 0.1580, 0.6420, 0.6931),
          T(0.0031, 0.6751, 0.1537, 0.5281, 0.1162, 0.4431, 0.2135, 0.2118),
          T(0.6561, 0.3722, 0.3653, 0.7055, 0.0839, 0.1767, 0.7989, 0.9738),
          T(0.2665, 0.1409, 0.7630, 0.9691, 0.3708, 0.0624, 0.5867, 0.7174)))))

    val in_channels: Int = 3
    val num_classes: Int = 2
    val dim_reduced: Int = 10
    val resolution: Int = 14
    val scales: Array[Float] = Array[Float](0.25f, 0.125f)
    val samplingRratio: Int = 2
    val layers: Array[Int] = Array[Int](4, 4)
    val dilation: Int = 1
    val numClasses: Int = 81
    val useGn: Boolean = false

    val mask = new MaskHead(in_channels, resolution, scales,
      samplingRratio, layers, dilation, numClasses, useGn)
    val layer = mask.maskPredictor(in_channels, num_classes, dim_reduced)
    val params = layer.getParameters()
    params._1.fill(0.01f)

    val output = layer.forward(input).toTensor[Float]

    TestUtils.conditionFailTest(output.size(1) == 2 && output.size(2) == 2 &&
      output.size(3) == 16 && output.size(4) == 16)

    val expectedOutput = Tensor[Float](T(T(T(
      T(0.0121, 0.0121, 0.0135, 0.0135, 0.0123, 0.0123, 0.0124, 0.0124,
        0.0122, 0.0122, 0.0125, 0.0125, 0.0125, 0.0125, 0.0127, 0.0127),
      T(0.0121, 0.0121, 0.0135, 0.0135, 0.0123, 0.0123, 0.0124, 0.0124,
        0.0122, 0.0122, 0.0125, 0.0125, 0.0125, 0.0125, 0.0127, 0.0127),
      T(0.0123, 0.0123, 0.0124, 0.0124, 0.0118, 0.0118, 0.0127, 0.0127,
        0.0135, 0.0135, 0.0133, 0.0133, 0.0122, 0.0122, 0.0129, 0.0129),
      T(0.0123, 0.0123, 0.0124, 0.0124, 0.0118, 0.0118, 0.0127, 0.0127,
        0.0135, 0.0135, 0.0133, 0.0133, 0.0122, 0.0122, 0.0129, 0.0129),
      T(0.0125, 0.0125, 0.0131, 0.0131, 0.0120, 0.0120, 0.0119, 0.0119,
        0.0131, 0.0131, 0.0121, 0.0121, 0.0119, 0.0119, 0.0134, 0.0134),
      T(0.0125, 0.0125, 0.0131, 0.0131, 0.0120, 0.0120, 0.0119, 0.0119,
        0.0131, 0.0131, 0.0121, 0.0121, 0.0119, 0.0119, 0.0134, 0.0134),
      T(0.0124, 0.0124, 0.0126, 0.0126, 0.0123, 0.0123, 0.0117, 0.0117,
        0.0127, 0.0127, 0.0121, 0.0121, 0.0124, 0.0124, 0.0123, 0.0123),
      T(0.0124, 0.0124, 0.0126, 0.0126, 0.0123, 0.0123, 0.0117, 0.0117,
        0.0127, 0.0127, 0.0121, 0.0121, 0.0124, 0.0124, 0.0123, 0.0123),
      T(0.0122, 0.0122, 0.0131, 0.0131, 0.0119, 0.0119, 0.0120, 0.0120,
        0.0127, 0.0127, 0.0124, 0.0124, 0.0127, 0.0127, 0.0125, 0.0125),
      T(0.0122, 0.0122, 0.0131, 0.0131, 0.0119, 0.0119, 0.0120, 0.0120,
        0.0127, 0.0127, 0.0124, 0.0124, 0.0127, 0.0127, 0.0125, 0.0125),
      T(0.0117, 0.0117, 0.0121, 0.0121, 0.0129, 0.0129, 0.0127, 0.0127,
        0.0123, 0.0123, 0.0132, 0.0132, 0.0112, 0.0112, 0.0120, 0.0120),
      T(0.0117, 0.0117, 0.0121, 0.0121, 0.0129, 0.0129, 0.0127, 0.0127,
        0.0123, 0.0123, 0.0132, 0.0132, 0.0112, 0.0112, 0.0120, 0.0120),
      T(0.0126, 0.0126, 0.0131, 0.0131, 0.0126, 0.0126, 0.0129, 0.0129,
        0.0123, 0.0123, 0.0125, 0.0125, 0.0130, 0.0130, 0.0123, 0.0123),
      T(0.0126, 0.0126, 0.0131, 0.0131, 0.0126, 0.0126, 0.0129, 0.0129,
        0.0123, 0.0123, 0.0125, 0.0125, 0.0130, 0.0130, 0.0123, 0.0123),
      T(0.0124, 0.0124, 0.0133, 0.0133, 0.0127, 0.0127, 0.0128, 0.0128,
        0.0126, 0.0126, 0.0120, 0.0120, 0.0128, 0.0128, 0.0127, 0.0127),
      T(0.0124, 0.0124, 0.0133, 0.0133, 0.0127, 0.0127, 0.0128, 0.0128,
        0.0126, 0.0126, 0.0120, 0.0120, 0.0128, 0.0128, 0.0127, 0.0127)),
      T(T(0.0121, 0.0121, 0.0135, 0.0135, 0.0123, 0.0123, 0.0124, 0.0124,
        0.0122, 0.0122, 0.0125, 0.0125, 0.0125, 0.0125, 0.0127, 0.0127),
        T(0.0121, 0.0121, 0.0135, 0.0135, 0.0123, 0.0123, 0.0124, 0.0124,
          0.0122, 0.0122, 0.0125, 0.0125, 0.0125, 0.0125, 0.0127, 0.0127),
        T(0.0123, 0.0123, 0.0124, 0.0124, 0.0118, 0.0118, 0.0127, 0.0127,
          0.0135, 0.0135, 0.0133, 0.0133, 0.0122, 0.0122, 0.0129, 0.0129),
        T(0.0123, 0.0123, 0.0124, 0.0124, 0.0118, 0.0118, 0.0127, 0.0127,
          0.0135, 0.0135, 0.0133, 0.0133, 0.0122, 0.0122, 0.0129, 0.0129),
        T(0.0125, 0.0125, 0.0131, 0.0131, 0.0120, 0.0120, 0.0119, 0.0119,
          0.0131, 0.0131, 0.0121, 0.0121, 0.0119, 0.0119, 0.0134, 0.0134),
        T(0.0125, 0.0125, 0.0131, 0.0131, 0.0120, 0.0120, 0.0119, 0.0119,
          0.0131, 0.0131, 0.0121, 0.0121, 0.0119, 0.0119, 0.0134, 0.0134),
        T(0.0124, 0.0124, 0.0126, 0.0126, 0.0123, 0.0123, 0.0117, 0.0117,
          0.0127, 0.0127, 0.0121, 0.0121, 0.0124, 0.0124, 0.0123, 0.0123),
        T(0.0124, 0.0124, 0.0126, 0.0126, 0.0123, 0.0123, 0.0117, 0.0117,
          0.0127, 0.0127, 0.0121, 0.0121, 0.0124, 0.0124, 0.0123, 0.0123),
        T(0.0122, 0.0122, 0.0131, 0.0131, 0.0119, 0.0119, 0.0120, 0.0120,
          0.0127, 0.0127, 0.0124, 0.0124, 0.0127, 0.0127, 0.0125, 0.0125),
        T(0.0122, 0.0122, 0.0131, 0.0131, 0.0119, 0.0119, 0.0120, 0.0120,
          0.0127, 0.0127, 0.0124, 0.0124, 0.0127, 0.0127, 0.0125, 0.0125),
        T(0.0117, 0.0117, 0.0121, 0.0121, 0.0129, 0.0129, 0.0127, 0.0127,
          0.0123, 0.0123, 0.0132, 0.0132, 0.0112, 0.0112, 0.0120, 0.0120),
        T(0.0117, 0.0117, 0.0121, 0.0121, 0.0129, 0.0129, 0.0127, 0.0127,
          0.0123, 0.0123, 0.0132, 0.0132, 0.0112, 0.0112, 0.0120, 0.0120),
        T(0.0126, 0.0126, 0.0131, 0.0131, 0.0126, 0.0126, 0.0129, 0.0129,
          0.0123, 0.0123, 0.0125, 0.0125, 0.0130, 0.0130, 0.0123, 0.0123),
        T(0.0126, 0.0126, 0.0131, 0.0131, 0.0126, 0.0126, 0.0129, 0.0129,
          0.0123, 0.0123, 0.0125, 0.0125, 0.0130, 0.0130, 0.0123, 0.0123),
        T(0.0124, 0.0124, 0.0133, 0.0133, 0.0127, 0.0127, 0.0128, 0.0128,
          0.0126, 0.0126, 0.0120, 0.0120, 0.0128, 0.0128, 0.0127, 0.0127),
        T(0.0124, 0.0124, 0.0133, 0.0133, 0.0127, 0.0127, 0.0128, 0.0128,
          0.0126, 0.0126, 0.0120, 0.0120, 0.0128, 0.0128, 0.0127, 0.0127))),
      T(T(T(0.0115, 0.0115, 0.0121, 0.0121, 0.0126, 0.0126, 0.0116, 0.0116,
        0.0123, 0.0123, 0.0124, 0.0124, 0.0115, 0.0115, 0.0118, 0.0118),
        T(0.0115, 0.0115, 0.0121, 0.0121, 0.0126, 0.0126, 0.0116, 0.0116,
          0.0123, 0.0123, 0.0124, 0.0124, 0.0115, 0.0115, 0.0118, 0.0118),
        T(0.0125, 0.0125, 0.0131, 0.0131, 0.0125, 0.0125, 0.0121, 0.0121,
          0.0116, 0.0116, 0.0123, 0.0123, 0.0126, 0.0126, 0.0127, 0.0127),
        T(0.0125, 0.0125, 0.0131, 0.0131, 0.0125, 0.0125, 0.0121, 0.0121,
          0.0116, 0.0116, 0.0123, 0.0123, 0.0126, 0.0126, 0.0127, 0.0127),
        T(0.0124, 0.0124, 0.0131, 0.0131, 0.0118, 0.0118, 0.0122, 0.0122,
          0.0113, 0.0113, 0.0125, 0.0125, 0.0118, 0.0118, 0.0115, 0.0115),
        T(0.0124, 0.0124, 0.0131, 0.0131, 0.0118, 0.0118, 0.0122, 0.0122,
          0.0113, 0.0113, 0.0125, 0.0125, 0.0118, 0.0118, 0.0115, 0.0115),
        T(0.0129, 0.0129, 0.0134, 0.0134, 0.0127, 0.0127, 0.0123, 0.0123,
          0.0123, 0.0123, 0.0120, 0.0120, 0.0119, 0.0119, 0.0126, 0.0126),
        T(0.0129, 0.0129, 0.0134, 0.0134, 0.0127, 0.0127, 0.0123, 0.0123,
          0.0123, 0.0123, 0.0120, 0.0120, 0.0119, 0.0119, 0.0126, 0.0126),
        T(0.0120, 0.0120, 0.0132, 0.0132, 0.0129, 0.0129, 0.0130, 0.0130,
          0.0123, 0.0123, 0.0124, 0.0124, 0.0121, 0.0121, 0.0119, 0.0119),
        T(0.0120, 0.0120, 0.0132, 0.0132, 0.0129, 0.0129, 0.0130, 0.0130,
          0.0123, 0.0123, 0.0124, 0.0124, 0.0121, 0.0121, 0.0119, 0.0119),
        T(0.0126, 0.0126, 0.0121, 0.0121, 0.0118, 0.0118, 0.0126, 0.0126,
          0.0126, 0.0126, 0.0122, 0.0122, 0.0116, 0.0116, 0.0118, 0.0118),
        T(0.0126, 0.0126, 0.0121, 0.0121, 0.0118, 0.0118, 0.0126, 0.0126,
          0.0126, 0.0126, 0.0122, 0.0122, 0.0116, 0.0116, 0.0118, 0.0118),
        T(0.0120, 0.0120, 0.0121, 0.0121, 0.0121, 0.0121, 0.0124, 0.0124,
          0.0124, 0.0124, 0.0129, 0.0129, 0.0124, 0.0124, 0.0132, 0.0132),
        T(0.0120, 0.0120, 0.0121, 0.0121, 0.0121, 0.0121, 0.0124, 0.0124,
          0.0124, 0.0124, 0.0129, 0.0129, 0.0124, 0.0124, 0.0132, 0.0132),
        T(0.0121, 0.0121, 0.0118, 0.0118, 0.0128, 0.0128, 0.0126, 0.0126,
          0.0116, 0.0116, 0.0119, 0.0119, 0.0125, 0.0125, 0.0127, 0.0127),
        T(0.0121, 0.0121, 0.0118, 0.0118, 0.0128, 0.0128, 0.0126, 0.0126,
          0.0116, 0.0116, 0.0119, 0.0119, 0.0125, 0.0125, 0.0127, 0.0127)),
        T(T(0.0115, 0.0115, 0.0121, 0.0121, 0.0126, 0.0126, 0.0116, 0.0116,
          0.0123, 0.0123, 0.0124, 0.0124, 0.0115, 0.0115, 0.0118, 0.0118),
          T(0.0115, 0.0115, 0.0121, 0.0121, 0.0126, 0.0126, 0.0116, 0.0116,
            0.0123, 0.0123, 0.0124, 0.0124, 0.0115, 0.0115, 0.0118, 0.0118),
          T(0.0125, 0.0125, 0.0131, 0.0131, 0.0125, 0.0125, 0.0121, 0.0121,
            0.0116, 0.0116, 0.0123, 0.0123, 0.0126, 0.0126, 0.0127, 0.0127),
          T(0.0125, 0.0125, 0.0131, 0.0131, 0.0125, 0.0125, 0.0121, 0.0121,
            0.0116, 0.0116, 0.0123, 0.0123, 0.0126, 0.0126, 0.0127, 0.0127),
          T(0.0124, 0.0124, 0.0131, 0.0131, 0.0118, 0.0118, 0.0122, 0.0122,
            0.0113, 0.0113, 0.0125, 0.0125, 0.0118, 0.0118, 0.0115, 0.0115),
          T(0.0124, 0.0124, 0.0131, 0.0131, 0.0118, 0.0118, 0.0122, 0.0122,
            0.0113, 0.0113, 0.0125, 0.0125, 0.0118, 0.0118, 0.0115, 0.0115),
          T(0.0129, 0.0129, 0.0134, 0.0134, 0.0127, 0.0127, 0.0123, 0.0123,
            0.0123, 0.0123, 0.0120, 0.0120, 0.0119, 0.0119, 0.0126, 0.0126),
          T(0.0129, 0.0129, 0.0134, 0.0134, 0.0127, 0.0127, 0.0123, 0.0123,
            0.0123, 0.0123, 0.0120, 0.0120, 0.0119, 0.0119, 0.0126, 0.0126),
          T(0.0120, 0.0120, 0.0132, 0.0132, 0.0129, 0.0129, 0.0130, 0.0130,
            0.0123, 0.0123, 0.0124, 0.0124, 0.0121, 0.0121, 0.0119, 0.0119),
          T(0.0120, 0.0120, 0.0132, 0.0132, 0.0129, 0.0129, 0.0130, 0.0130,
            0.0123, 0.0123, 0.0124, 0.0124, 0.0121, 0.0121, 0.0119, 0.0119),
          T(0.0126, 0.0126, 0.0121, 0.0121, 0.0118, 0.0118, 0.0126, 0.0126,
            0.0126, 0.0126, 0.0122, 0.0122, 0.0116, 0.0116, 0.0118, 0.0118),
          T(0.0126, 0.0126, 0.0121, 0.0121, 0.0118, 0.0118, 0.0126, 0.0126,
            0.0126, 0.0126, 0.0122, 0.0122, 0.0116, 0.0116, 0.0118, 0.0118),
          T(0.0120, 0.0120, 0.0121, 0.0121, 0.0121, 0.0121, 0.0124, 0.0124,
            0.0124, 0.0124, 0.0129, 0.0129, 0.0124, 0.0124, 0.0132, 0.0132),
          T(0.0120, 0.0120, 0.0121, 0.0121, 0.0121, 0.0121, 0.0124, 0.0124,
            0.0124, 0.0124, 0.0129, 0.0129, 0.0124, 0.0124, 0.0132, 0.0132),
          T(0.0121, 0.0121, 0.0118, 0.0118, 0.0128, 0.0128, 0.0126, 0.0126,
            0.0116, 0.0116, 0.0119, 0.0119, 0.0125, 0.0125, 0.0127, 0.0127),
          T(0.0121, 0.0121, 0.0118, 0.0118, 0.0128, 0.0128, 0.0126, 0.0126,
            0.0116, 0.0116, 0.0119, 0.0119, 0.0125, 0.0125, 0.0127, 0.0127)))))

    output.almostEqual(expectedOutput, 1e-4) should be(true)
  }
}

/**
 * Serialization test for [[MaskHead]].
 *
 * Builds a MaskHead with the same hyper-parameters used in `MaskHeadSpec`,
 * names it "MaskHead", and hands it together with a sample input table to
 * `runSerializationTest`.
 */
class MaskHeadSerialTest extends ModuleSerializationTest {
  override def test(): Unit = {
    // Constructor arguments, passed positionally in MaskHead's parameter order.
    val channelsIn: Int = 6
    val poolerResolution: Int = 14
    val poolerScales: Array[Float] = Array(0.25f, 0.125f)
    val samplingRatio: Int = 2
    val convLayers: Array[Int] = Array(4, 4)
    val convDilation: Int = 1
    val classCount: Int = 81
    val withGroupNorm: Boolean = false

    // The module is built before any input tensor is randomized, so the
    // random-number draws happen in the same order as before.
    val maskHead = new MaskHead(
      channelsIn,
      poolerResolution,
      poolerScales,
      samplingRatio,
      convLayers,
      convDilation,
      classCount,
      withGroupNorm
    ).setName("MaskHead")

    // Two randomly filled feature maps of shapes 1x6x3x4 and 1x6x5x2.
    // Scala evaluates arguments left to right, so the 3x4 map is filled first.
    val featureMaps = T(
      Tensor[Float](1, 6, 3, 4).rand(),
      Tensor[Float](1, 6, 5, 2).rand()
    )

    // Two boxes, four coordinates each, wrapped in a single-element table.
    val boxTable = T(
      Tensor[Float](T(
        T(1.0f, 3.0f, 2.0f, 6.0f),
        T(3.0f, 5.0f, 6.0f, 10.0f)
      ))
    )
    val classLabels = Tensor[Float](T(1, 3))

    val sampleInput = T(featureMaps, boxTable, classLabels)
    runSerializationTest(maskHead, sampleInput)
  }
}
